from hpo.hpo_base import *
import ConfigSpace as CS
import torch
import numpy as np
import sys
import math
import gpytorch
from hpo.utils import train_gp, copula_standardize, interleaved_search, get_reward_from_trajectory, grad_search
from copy import deepcopy
import time
import pandas as pd
import logging
import pickle
import os
from typing import Callable

# Module-level configuration constants (not OS environment variables)
MAX_CHOLESKY_SIZE = 2000
MIN_CUDA = 1024
DEVICE = 'cpu'


class Casmo4RL(HyperparameterOptimizer):
    """
    Casmopolitan for RL

    In addition to the original Casmopolitan, this implementation also supports wrapping of the conditional search space
    based on
        Lévesque, Julien-Charles, et al. "Bayesian optimization for conditional hyperparameter spaces."
        2017 International Joint Conference on Neural Networks (IJCNN). IEEE, 2017.

    and integer-search space wrapping:
        Garrido-Merchán and Hernández-Lobato. "Dealing with Categorical and Integer-valued Variables in Bayesian
         Optimization with Gaussian Processes". Neurocomputing, 2020

    """

    def __init__(self, env, log_dir,
                 max_iters: int,
                 max_timesteps: int = None,
                 batch_size: int = 1,
                 n_init: int = None,
                 verbose: bool = True,
                 ard='auto',
                 multi_fidelity: bool = False,
                 use_reward: float = 0.,
                 log_interval: int = 1,
                 previous_trainer=None,
                 time_varying=False,
                 current_timestep: int = 0,
                 acq: str = 'lcb',
                 obj_func: Callable = None,
                 seed: int = None,
                 use_standard_gp: bool = False,
                 ):
        """Set up the optimizer state and the underlying ``_Casmo`` engine.

        Parameters
        ----------
        env :
            Environment object. This constructor reads ``env.config_space`` and, when
            ``max_timesteps`` is None, ``env.default_env_args['num_timesteps']``; it also
            writes ``seed`` to ``env.seed``.
        log_dir :
            Directory for intermediate results; created if it does not exist. ``None``
            skips directory creation (``run`` only saves when ``log_dir`` is not None).
        max_iters, batch_size :
            Forwarded to the base-class constructor.
        max_timesteps :
            Number of timesteps passed to the environment for each configuration.
        n_init :
            Number of random initial points. Defaults to ``min(10, 2 * dim + 1)`` when
            None or non-positive.
        verbose, ard, acq, time_varying, use_standard_gp :
            Forwarded to ``_Casmo``.
        multi_fidelity :
            Not implemented; a truthy value raises ``NotImplementedError``.
        use_reward :
            Passed as ``use_last_fraction`` to ``get_reward_from_trajectory`` by the
            default objective.
        log_interval :
            ``run`` saves intermediate results when ``cur_iters % log_interval == 0``.
        previous_trainer :
            Deep-copied and stored as ``self.previous_trainer``.
        current_timestep :
            Stored as ``self.current_timestep``.
        obj_func :
            Objective called with a list of configurations; defaults to
            ``self._obj_func_handle``.
        seed :
            Stored on both the optimizer and the environment.

        Raises
        ------
        NotImplementedError
            If ``multi_fidelity`` is truthy.
        """
        super().__init__(env, max_iters, batch_size, 1)
        self.max_timesteps = max_timesteps if max_timesteps is not None else env.default_env_args['num_timesteps']
        self.log_dir = log_dir
        # ``run`` treats ``log_dir=None`` as "do not save", so tolerate None here as well.
        # ``exist_ok=True`` avoids the check-then-create race of ``os.path.exists``.
        if log_dir is not None:
            os.makedirs(log_dir, exist_ok=True)
        self.verbose = verbose
        self.cur_iters = 0
        self.dim = len(env.config_space.get_hyperparameters())
        self.log_interval = log_interval
        self.n_init = n_init if n_init is not None and n_init > 0 else min(10, 2 * self.dim + 1)
        self.previous_trainer = deepcopy(previous_trainer)

        # settings related to the time-varying GP
        self.time_varying = time_varying
        self.current_timestep = current_timestep
        self.use_standard_gp = use_standard_gp

        if multi_fidelity:
            raise NotImplementedError  # todo
        self.seed = self.env.seed = seed
        self.ard = ard
        self.casmo = _Casmo(env.config_space,
                            n_init=self.n_init,
                            max_evals=self.max_iters,
                            batch_size=None,  # this will be updated later. batch_size=None signifies initialisation
                            verbose=verbose,
                            ard=ard,
                            acq=acq,
                            use_standard_gp=self.use_standard_gp,
                            time_varying=time_varying)
        self.X_init = None
        self.use_reward = use_reward
        # save the RL learning trajectory for each run of the BO
        self.trajectories = []
        self.f = obj_func if obj_func is not None else self._obj_func_handle

    def restart(self):
        """Reset the trust region and draw a fresh batch of random initial points.

        The per-trust-region buffers of the engine are emptied and ``self.X_init`` is
        refilled with ``self.n_init`` configurations sampled from the config space.
        """
        engine = self.casmo
        engine._restart()
        # start the trust-region-local history from scratch
        engine._X = np.zeros((0, engine.dim))
        engine._fX = np.zeros((0, 1))
        fresh_points = []
        for _ in range(self.n_init):
            fresh_points.append(self.env.config_space.sample_configuration().get_array())
        self.X_init = np.array(fresh_points)

    def suggest(self, n_suggestions=1, ):
        """Propose ``n_suggestions`` configurations as an array of shape (n_suggestions, dim).

        Pending random initial points (``self.X_init``) are consumed first. Any remaining
        slots are filled by optimising the acquisition function on a GP fitted to the
        data of the current trust region. If no such data exists yet, the remaining
        slots are filled with random samples from the configuration space.

        Parameters
        ----------
        n_suggestions : int
            Number of points to return. The value of the first call is remembered as the
            engine's batch size.

        Returns
        -------
        np.ndarray
            Suggested points in the vector representation of the configuration space.
        """
        if self.casmo.batch_size is None:  # Remember the batch size on the first call to suggest
            self.casmo.batch_size = n_suggestions
            self.casmo.n_init = max([self.casmo.n_init, self.batch_size])
            self.restart()

        X_next = np.zeros((n_suggestions, self.dim))

        # Pick from the initial points
        n_init = min(len(self.X_init), n_suggestions)
        if n_init > 0:
            X_next[:n_init] = deepcopy(self.X_init[:n_init, :])
            self.X_init = self.X_init[n_init:, :]  # Remove these pending points

        # Get remaining points from the trust-region BO engine
        n_adapt = n_suggestions - n_init
        if n_adapt > 0:
            if len(self.casmo._X) > 0:
                X = deepcopy(self.casmo._X)
                fX = copula_standardize(deepcopy(self.casmo._fX).ravel())  # Use Copula
                X_next[-n_adapt:, :] = self.casmo._create_and_select_candidates(
                    X, fX,
                    length_cont=self.casmo.length,
                    length_cat=self.casmo.length_cat,
                    n_training_steps=100,
                    hypers={},
                )[-n_adapt:, :]
            else:
                # No observations to fit a GP on: fall back to random points instead of
                # returning the all-zero placeholder rows.
                X_next[-n_adapt:, :] = np.array(
                    [self.env.config_space.sample_configuration().get_array() for _ in range(n_adapt)]
                )
        return X_next

    def suggest_conditional_on_fixed_dims(self, fixed_dims, fixed_vals, n_suggestions=1):
        """Suggest points from the BO surrogate while some dimensions are held fixed.

        ``fixed_dims`` and ``fixed_vals`` must have the same length; they are forwarded
        as ``frozen_dims`` / ``frozen_vals`` to the engine's candidate search, which is
        run on the data of the current trust region with ``n_suggestions`` as batch size.
        """
        assert len(fixed_vals) == len(fixed_dims)
        engine = self.casmo
        train_x = deepcopy(engine._X)
        train_y = copula_standardize(deepcopy(engine._fX).ravel())  # Use Copula
        search_options = dict(
            length_cont=engine.length,
            length_cat=engine.length_cat,
            n_training_steps=100,
            frozen_dims=fixed_dims,
            frozen_vals=fixed_vals,
            batch_size=n_suggestions,
            hypers={},
        )
        return engine._create_and_select_candidates(train_x, train_y, **search_options)

    def observe(self, X, y, t=None):
        """Feed evaluated points back into the optimizer.

        Parameters
        ----------
        X : array-like, one row per evaluated point
            Points that were evaluated; stacked onto the engine's ``_X`` and ``X``.
        y : array-like, shape (n,)
            Objective values at ``X``.
        t : array-like, shape (n,), optional
            Timesteps associated with the observations.

        The trust region is adjusted once at least ``n_init`` observations were already
        stored, and a restart is triggered when the trust-region length has reached its
        minimum.
        """
        assert len(X) == len(y)
        has_time = t is not None
        if has_time:
            assert len(t) == len(y)
        engine = self.casmo
        y_col = np.array(y)[:, None]
        t_col = np.array(t)[:, None] if has_time else None

        # adapt the trust region only after the initial design has been observed
        if len(engine._fX) >= engine.n_init:
            engine._adjust_length(y_col)

        engine.n_evals += self.batch_size

        new_x = deepcopy(X)
        new_y = deepcopy(y_col.reshape(-1, 1))
        # current-trust-region buffers
        engine._X = np.vstack((engine._X, new_x))
        engine._fX = np.vstack((engine._fX, new_y))
        # full history
        engine.X = np.vstack((engine.X, new_x))
        engine.fX = np.vstack((engine.fX, new_y))
        if t_col is not None:
            new_t = deepcopy(t_col.reshape(-1, 1))
            engine._t = np.vstack((engine._t, new_t))
            engine.t = np.vstack((engine.t, new_t))

        # Check for a restart
        if engine.length <= engine.length_min:
            self.restart()

    def run(self):
        """Run the suggest / evaluate / observe loop until ``max_iters`` evaluations.

        Each iteration asks for ``batch_size`` suggestions, evaluates them with
        ``self.f`` and reports the results back via ``observe``. Per-evaluation
        statistics go into ``self.res``. When ``log_dir`` is not None, the statistics,
        the evaluated points and the trajectories are saved to ``stats-pandas.csv``,
        ``stats.pkl`` and ``trajectories.pkl`` whenever
        ``cur_iters % log_interval == 0``.

        Returns
        -------
        tuple
            ``(self.X, self.y)``: the evaluated configurations and their objective values.
        """
        self.cur_iters = 0
        self.res = pd.DataFrame(np.nan, index=np.arange(self.max_iters + self.batch_size),
                                columns=['Index', 'LastValue', 'BestValue', 'Time'])
        self.X, self.y = [], []
        while self.cur_iters < self.max_iters:
            logging.info(f'Current iter = {self.cur_iters + 1} / {self.max_iters}')
            start = time.time()
            suggested_config_arrays = self.suggest(self.batch_size)
            # convert suggestions from np array to a valid configuration.
            suggested_configs = [CS.Configuration(self.env.config_space, vector=array) for array in
                                 suggested_config_arrays]
            rewards = self.f(suggested_configs)
            # Give a scalar reward a length *before* it is appended to the history
            # (extending a list with a bare float raises TypeError).
            if np.ndim(rewards) == 0:
                rewards = np.array(rewards).reshape(1)
            self.X += suggested_configs
            self.y.extend(rewards)
            self.observe(suggested_config_arrays, rewards)
            end = time.time()
            if self.y:
                elapsed = end - start
                last_new = min(self.cur_iters + self.batch_size, len(self.y))
                for j in range(self.cur_iters, last_new):
                    self.res.iloc[j, :] = [j, float(self.y[j]), float(np.min(self.y[:j + 1])), elapsed]
                # best value over everything evaluated so far, including the whole current batch
                argmin = np.argmin(self.y)

                logging.info(
                    f'fX={rewards}.'
                    f'fX_best={self.y[argmin]}'
                )
                if self.cur_iters % self.log_interval == 0 and self.log_dir is not None:
                    logging.info(f'Saving intermediate results to {os.path.join(self.log_dir, "stats.pkl")}')
                    self.res.to_csv(os.path.join(self.log_dir, 'stats-pandas.csv'))
                    # context managers make sure the files are flushed and closed
                    with open(os.path.join(self.log_dir, 'stats.pkl'), 'wb') as stats_file:
                        pickle.dump([self.X, self.y], stats_file)
                    with open(os.path.join(self.log_dir, 'trajectories.pkl'), 'wb') as traj_file:
                        pickle.dump(self.trajectories, traj_file)
            self.cur_iters += self.batch_size

        return self.X, self.y

    def _obj_func_handle(self, config: list, ) -> list:
        """Default objective: train the environment on each configuration.

        Every configuration is trained for ``self.max_timesteps`` timesteps with
        ``self.seed``. The returned trajectories are appended to ``self.trajectories``
        and the negated reward extracted from each trajectory's ``'y'`` entry is returned.
        """
        n_configs = len(config)
        batch_trajectories = self.env.train_batch(
            config,
            exp_idx_start=self.cur_iters,
            nums_timesteps=[self.max_timesteps] * n_configs,
            seeds=[self.seed] * n_configs,
        )
        self.trajectories += batch_trajectories
        negated_rewards = []
        for trajectory in batch_trajectories:
            value = get_reward_from_trajectory(np.array(trajectory['y']), use_last_fraction=self.use_reward)
            negated_rewards.append(-value)
        return negated_rewards

    def get_surrogate(self, current_tr_only=False):
        """Fit and return a GP surrogate on the collected observations.

        Parameters
        ----------
        current_tr_only : bool
            If True, use only the data gathered since the last trust-region restart
            (``_X`` / ``_fX``); otherwise use the full history (``X`` / ``fX``).

        Returns
        -------
        The model returned by ``train_gp``.

        Raises
        ------
        ValueError
            If no observation has been recorded yet.
        """
        if not self.casmo.fX.shape[0]:
            raise ValueError("Casmo does not currently have any observation data!")
        if current_tr_only:
            # _X and _fX only hold the data collected since the last TR restart
            X = deepcopy(self.casmo._X)
            y = deepcopy(self.casmo._fX).flatten()
        else:
            X = deepcopy(self.casmo.X)
            y = deepcopy(self.casmo.fX).flatten()

        # if not explicitly specified, ARD is only enabled when we have a reasonable amount of training data.
        if self.ard in [True, False]:
            ard = self.ard
        else:
            ard = y.shape[0] > 150
        if len(X) < self.casmo.min_cuda:
            device, dtype = torch.device("cpu"), torch.float32
        else:
            device, dtype = self.casmo.device, self.casmo.dtype
        with gpytorch.settings.max_cholesky_size(MAX_CHOLESKY_SIZE):
            X_torch = torch.tensor(X).to(device=device, dtype=dtype)
            y_torch = torch.tensor(y).to(device=device, dtype=dtype)
            # Small jitter for numerical stability. ``randn_like`` creates the noise on the
            # same device and dtype as ``y_torch``, so this also works for CUDA tensors.
            y_torch += torch.randn_like(y_torch) * 1e-5
            gp = train_gp(
                configspace=self.casmo.cs,
                train_x=X_torch,
                train_y=y_torch,
                use_ard=ard,
                num_steps=100,
                noise_variance=None
            )
        return gp


class _Casmo:
    """A private class adapted from the TurBO code base"""

    def __init__(self, cs: CS.ConfigurationSpace,
                 n_init,
                 max_evals,
                 batch_size: int = None,
                 verbose: bool = True,
                 ard='auto',
                 acq: str = 'ei',
                 beta: float = None,
                 prior_surrogate=None,
                 time_varying: bool = False,
                 use_standard_gp: bool = False,
                 **kwargs):
        """Initialise the trust-region BO engine.

        ``max_evals`` and ``n_init`` must be positive ints. Optional ``kwargs`` read here
        are ``n_cand``, ``multiplier``, ``failtol``, ``succtol``, ``length_min``,
        ``length_max``, ``length_init``, ``length_min_cat``, ``length_max_cat`` and
        ``length_init_cat``; all keyword arguments are kept in ``self.kwargs``.
        With ``use_standard_gp=True`` the tolerances and lengths are set to extreme
        values, which in effect disables the trust region.
        """
        # sanity checks on the budget parameters
        assert max_evals > 0 and isinstance(max_evals, int)
        assert n_init > 0 and isinstance(n_init, int)
        if DEVICE == "cuda":
            assert torch.cuda.is_available(), "can't use cuda if it's not available"

        # search space and user options
        self.cs = cs
        self.dim = len(cs.get_hyperparameters())
        self.batch_size = batch_size
        self.verbose = verbose
        self.use_ard = ard
        self.acq = acq
        self.kwargs = kwargs
        self.n_init = n_init

        # optional prior surrogate and its weighting parameter
        self.prior_surrogate = prior_surrogate
        if beta is None:
            beta = max_evals / 10.
        self.beta = beta
        self.time_varying = time_varying

        # containers for GP hyperparameters
        self.mean = np.zeros((0, 1))
        self.signal_var = np.zeros((0, 1))
        self.noise_var = np.zeros((0, 1))
        if self.use_ard:
            self.lengthscales = np.zeros((0, self.dim))
        else:
            self.lengthscales = np.zeros((0, 1))
        self.n_restart = 3  # number of restarts for each acquisition optimization

        # tolerances and counters
        self.n_cand = kwargs.get('n_cand', min(100 * self.dim, 5000))
        self.use_standard_gp = use_standard_gp
        self.n_evals = 0

        if use_standard_gp:
            # extreme settings: this in effect disables any trust region
            logging.info('Initializing a standard GP without trust region or interleaved acquisition search.')
            self.tr_multiplier = 1.
            self.failtol = 100000
            self.succtol = 100000
            self.length_min = -1
            self.length_min_cat = -1
            self.length_max = 100000
            self.length_max_cat = 100000
            self.length_init = 100000
            self.length_init_cat = 100000
        else:
            self.tr_multiplier = kwargs.get('multiplier', 1.5)
            self.failtol = kwargs.get('failtol', 10)
            self.succtol = kwargs.get('succtol', 3)

            # Trust region sizes for the continuous/int dimensions ...
            self.length_min = kwargs.get('length_min', 0.15)
            self.length_max = kwargs.get('length_max', 1.)
            self.length_init = kwargs.get('length_init', .4)

            # ... and for the categorical dimensions
            self.length_min_cat = kwargs.get('length_min_cat', 0.1)
            self.length_max_cat = kwargs.get('length_max_cat', 1.)
            self.length_init_cat = kwargs.get('length_init_cat', 1.)

        # Full history of points, values and timesteps (the latter in case the GP is time-varying)
        self.X = np.zeros((0, self.dim))
        self.fX = np.zeros((0, 1))
        self.t = np.zeros((0, 1))

        # Device and dtype for GPyTorch
        self.min_cuda = MIN_CUDA
        self.dtype = torch.float64
        self.device = torch.device("cuda" if DEVICE == "cuda" else "cpu")
        if self.verbose:
            print("Using dtype = %s \nUsing device = %s" % (self.dtype, self.device))
            sys.stdout.flush()

        self._restart()

    def _restart(self):
        self._X = np.zeros((0, self.dim))
        self._fX = np.zeros((0, 1))
        self._t = np.zeros((0, 1))
        self.failcount = 0
        self.succcount = 0
        self.length = self.length_init
        self.length_cat = self.length_init_cat

    def _adjust_length(self, fX_next):
        # print(fX_next, self._fX)
        if np.min(fX_next) <= np.min(self._fX) - 1e-3 * math.fabs(np.min(self._fX)):
            self.succcount += self.batch_size
            self.failcount = 0
        else:
            self.succcount = 0
            self.failcount += self.batch_size

        if self.succcount == self.succtol:  # Expand trust region
            self.length = min([self.tr_multiplier * self.length, self.length_max])
            # For the Hamming distance-bounded trust region, we additively (instead of multiplicatively) adjust.
            self.length_cat = min(self.length_cat * self.tr_multiplier, self.length_max_cat)
            # self.length = min(self.length * 1.5, self.length_max)
            self.succcount = 0
            logging.info(f'Expanding TR length to {self.length}')
        elif self.failcount == self.failtol:  # Shrink trust region
            # self.length = max([self.length_min, self.length / 2.0])
            self.failcount = 0
            # Ditto for shrinking.
            self.length_cat = max(self.length_cat / self.tr_multiplier, self.length_min_cat)
            self.length = max(self.length / self.tr_multiplier, self.length_min)
            logging.info(f'Shrinking TR length to {self.length}')

    def _create_and_select_candidates(self, X, fX, length_cat, length_cont,
                                      x_center=None,
                                      n_training_steps=100,
                                      hypers={}, return_acq=False,
                                      time_varying=None,
                                      t=None, batch_size=None,
                                      frozen_vals: list = None,
                                      frozen_dims: List[int] = None):
        d = X.shape[1]
        time_varying = time_varying if time_varying is not None else self.time_varying
        # assert X.min() >= 0.0 and X.max() <= 1.0
        # Figure out what device we are running on
        if batch_size is None:
            batch_size = self.batch_size
        if self.use_ard in [True, False]:
            ard = self.use_ard
        else:
            ard = True if fX.shape[0] > 150 else False  # turn on ARD only when there are many data
        if len(X) < self.min_cuda:
            device, dtype = torch.device("cpu"), torch.float32
        else:
            device, dtype = self.device, self.dtype
        with gpytorch.settings.max_cholesky_size(MAX_CHOLESKY_SIZE):
            X_torch = torch.tensor(X).to(device=device, dtype=dtype)
            # X_torch_nan_mask = torch.isnan(X_torch).to(device=device)
            # here we replace the nan values with zero, but record the nan locations via the X_torch_nan_mask
            # X_torch[X_torch_nan_mask] = 0.
            y_torch = torch.tensor(fX).to(device=device, dtype=dtype)
            y_torch += torch.randn(y_torch.size()) * 1e-5  # add some noise to improve numerical stability

            gp = train_gp(
                configspace=self.cs,
                train_x=X_torch,
                # train_x_mask=X_torch_nan_mask,
                train_y=y_torch,
                use_ard=ard,
                num_steps=n_training_steps,
                hypers=hypers,
                noise_variance=self.kwargs['noise_variance'] if
                'noise_variance' in self.kwargs else None,
                time_varying=time_varying and t is not None,
                train_t=t,
                verbose=self.verbose
            )
            # Save state dict
            hypers = gp.state_dict()

        # we are always optimizing the acquisition function at the latest timestep
        t_center = t.max() if time_varying else None
        # if x_center is None:
        #     if time_varying:
        #         eps = gp.covar_module.base_kernel.time_kernel.epsilon.detach().numpy().item()
        #         fX_penalty = (1. - eps) ** ((t_center - t) / 2) * fX
        #         x_center = X[fX_penalty.argmin().item(), :][None, :]
        #         logging.info(f'x_center={x_center}. eps={eps}')
        #     else:
        #         x_center = X[fX.argmin().item(), :][None, :]
        # else:
        #     logging.info(f'x_center={x_center}')

        def _ei(X, augmented=False):
            """Expected improvement (with option to enable augmented EI).
            This implementation assumes the objective function should be MINIMIZED, and the acquisition function should
                also be MINIMIZED (hence negative sign on both the GP prediction and the acquisition function value)
            """
            from torch.distributions import Normal
            if not isinstance(X, torch.Tensor):
                X = torch.tensor(X, dtype=dtype)
            if X.dim() == 1:
                X = X.reshape(1, -1)
            gauss = Normal(torch.zeros(1), torch.ones(1))
            # flip for minimization problems
            gp.eval()
            if time_varying:
                X = torch.hstack([t_center * torch.ones((X.shape[0], 1)), X])
            preds = gp(X)
            with gpytorch.settings.fast_pred_var():
                mean, std = -preds.mean, preds.stddev
            # use in-fill criterion
            # mu_star = -gp.likelihood(gp(torch.tensor(x_center[0].reshape(1, -1), dtype=dtype))).mean
            mu_star = -fX.min()

            u = (mean - mu_star) / std
            ucdf = gauss.cdf(u)
            updf = torch.exp(gauss.log_prob(u))
            ei = std * updf + (mean - mu_star) * ucdf
            if augmented:
                sigma_n = gp.likelihood.noise
                ei *= (1. - torch.sqrt(torch.clone(sigma_n)) / torch.sqrt(sigma_n + std ** 2))
            # incorporate the prior GP here, if applicable
            # if self.prior is not None:
            #     # print(self.prior.log_prob(X), (self.beta / fX.shape[0]))
            #     ei *= torch.exp(self.prior.log_prob(X)) ** (self.beta / fX.shape[0])
            return -ei

        def _lcb(X, beta=3.):
            if not isinstance(X, torch.Tensor):
                X = torch.tensor(X, dtype=dtype)
            if X.dim() == 1:
                X = X.reshape(1, -1)
            if time_varying:
                X = torch.hstack([t_center * torch.ones((X.shape[0], 1)), X])
            gp.eval()
            gp.likelihood.eval()
            preds = gp.likelihood(gp(X))
            # print(X, )
            with gpytorch.settings.fast_pred_var():
                # try:
                mean, std = preds.mean, preds.stddev
                lcb = mean - beta * std
                # very rarely gpytorch throws an error here when trying to infer the posterior variance/stddev.
                # Add a protection try-except block to make sure this does not stop the entire program.
                # except RuntimeError as e:
                #     logging.warning(f'Computing posterior failed with error message={e}')
                #     lcb = preds.mean
            return lcb

        if batch_size == 1:
            # Sequential setting
            if self.use_standard_gp:
                X_next, acq_next = grad_search(self.cs, x_center[0] if x_center is not None else None, eval(f'_{self.acq}'),
                                               n_restart=self.n_restart, batch_size=batch_size,
                                               verbose=self.verbose,
                                               dtype=dtype)
            else:
                X_next, acq_next = interleaved_search(self.cs,
                                                      d,
                                                      x_center[0] if x_center is not None else None,
                                                      eval(f'_{self.acq}'),
                                                      max_dist_cat=length_cat,
                                                      max_dist_cont=length_cont,
                                                      cont_int_lengthscales=gp.covar_module.base_kernel.lengthscale.cpu().detach().numpy().ravel(),
                                                      n_restart=self.n_restart, batch_size=batch_size,
                                                      verbose=self.verbose,
                                                      frozen_dims=frozen_dims,
                                                      frozen_vals=frozen_vals,
                                                      dtype=dtype)

        else:
            # batch setting: for these, we use the fantasised points {x, y}
            X_next = torch.tensor([], dtype=dtype, device=device)
            acq_next = np.array([])
            for p in range(batch_size):
                x_center_ = deepcopy(x_center[0]) if x_center is not None else None
                if self.use_standard_gp:
                    x_next, acq = grad_search(self.cs, x_center_, eval(f'_{self.acq}'),
                                              n_restart=self.n_restart, batch_size=1,
                                              dtype=dtype)
                else:
                    x_next, acq = interleaved_search(self.cs,
                                                     d,
                                                     x_center_,
                                                     eval(f'_{self.acq}'),
                                                     max_dist_cat=length_cat,
                                                     max_dist_cont=length_cont,
                                                     cont_int_lengthscales=gp.covar_module.base_kernel.lengthscale.cpu().detach().numpy().ravel(),
                                                     frozen_dims=frozen_dims,
                                                     frozen_vals=frozen_vals,
                                                     n_restart=self.n_restart, batch_size=1, dtype=dtype)

                x_next_torch = torch.tensor(x_next, dtype=dtype, device=device)
                if time_varying:
                    x_next_torch = x_next_torch[:, 1:]  # strip the time dimension

                # x_next_torch_nan_mask = torch.isnan(x_next_torch)
                # x_next_torch[x_next_torch_nan_mask] = 0.
                # The fantasy point is filled by the posterior mean of the Gaussian process.
                y_next = gp(x_next_torch).mean.detach()
                with gpytorch.settings.max_cholesky_size(MAX_CHOLESKY_SIZE):
                    X_torch = torch.cat((X_torch, x_next_torch), dim=0)
                    # X_torch_nan_mask = torch.cat((X_torch_nan_mask, x_next_torch_nan_mask), dim=0)
                    y_torch = torch.cat((y_torch, y_next), dim=0)
                    gp = train_gp(
                        configspace=self.cs,
                        train_x=X_torch, train_y=y_torch, use_ard=ard, num_steps=n_training_steps,
                        # train_x_mask=X_torch_nan_mask,
                        hypers=hypers,
                        noise_variance=self.kwargs['noise_variance'] if
                        'noise_variance' in self.kwargs else None,
                        time_varying=self.time_varying,
                        train_t=t,
                    )
                X_next = torch.cat((X_next, x_next_torch), dim=0)
                # X_next_nan_mask = torch.cat((X_next_nan_mask, x_next_torch_nan_mask), dim=0)
                acq_next = np.hstack((acq_next, acq))

        # Remove the torch tensors
        del X_torch, y_torch, gp
        X_next = np.array(X_next)
        if return_acq:
            return X_next, acq_next
        return X_next
